#ifndef STATSxx_DBN__DBN_HPP
#define STATSxx_DBN__DBN_HPP

// STL
#include <string> // std::string
#include <vector> // std::vector<>

// jScience
#include "jScience/linalg.hpp" // Matrix<>

// stats++
#include "statsxx/machine_learning/DBN.hpp" // machine_learning::DBN


/// @brief Obtain a machine_learning::DBN, either freshly built or (presumably)
///        loaded from a file.
///
/// NOTE(review): only this declaration is visible here; the parameter semantics
/// below are inferred from names and must be confirmed against the definition.
///
/// @param create             presumably: true => build a new network from
///                           create_nunits / create_units_type; false => read an
///                           existing network from dbn_file (TODO confirm)
/// @param create_nunits      presumably the unit count for each layer (TODO confirm)
/// @param create_units_type  presumably a per-layer unit-type code; the encoding is
///                           not visible here (TODO confirm)
/// @param dbn_file           path to a DBN file; whether it is read, written, or
///                           both is not visible from this header (TODO confirm)
/// @return                   the resulting network, returned by value
machine_learning::DBN create_DBN(
                                 const bool              create,
                                 // -----
                                 const std::vector<int> &create_nunits,
                                 const std::vector<int> &create_units_type,
                                 // -----
                                 const std::string       dbn_file
                                 );


/// @brief Train a deep belief network on the data in X and return the result.
///
/// `dbn` is taken by value, so the function works on its own copy. The caller's
/// object is not modified through this parameter. Presumably the returned DBN is
/// the trained network (TODO confirm against the definition).
///
/// NOTE(review): only this declaration is visible; every semantic description
/// below is inferred from parameter names and must be confirmed.
///
/// @param dbn                          network to train (copied in)
/// @param X                            training data; presumably one sample per
///                                     row (TODO confirm orientation)
/// @param train_nepoch                 presumably the maximum number of epochs
/// @param train_npts_per_batch         presumably the mini-batch size
/// @param train_K_start                presumably the initial number of Gibbs
///                                     steps for contrastive divergence (CD-K)
/// @param train_K_end                  presumably the final CD-K step count
/// @param train_K_rate                 presumably the rate at which K moves from
///                                     train_K_start to train_K_end (TODO confirm)
/// @param train_sample_v               presumably whether visible units are
///                                     sampled (vs. using probabilities)
/// @param train_sample_hdata           presumably whether data-driven hidden
///                                     units are sampled (vs. probabilities)
/// @param train_lr                     presumably the initial learning rate
/// @param train_lr_dot                 presumably thresholds used to adapt the
///                                     learning rate; semantics not visible (TODO)
/// @param train_lr_adj                 presumably learning-rate adjustment
///                                     factors paired with train_lr_dot (TODO)
/// @param train_lr_min                 presumably the lower bound on the learning rate
/// @param train_lr_max                 presumably the upper bound on the learning rate
/// @param train_p_start                presumably the initial momentum
/// @param train_p_end                  presumably the final momentum
/// @param train_p_rate                 presumably the momentum schedule rate
/// @param train_p_dot                  presumably thresholds for adapting momentum (TODO)
/// @param train_p_adj                  presumably momentum adjustment factors (TODO)
/// @param train_w_penalty              presumably a weight penalty / decay coefficient
/// @param train_convg_nno_improvement  presumably the number of non-improving
///                                     iterations tolerated before stopping (TODO)
/// @param train_convg_max_inc_max_w    presumably a bound on growth of the maximum
///                                     weight used as a convergence test (TODO)
/// @param dbn_file                     path to a DBN file; presumably where the
///                                     trained network is saved (TODO confirm)
/// @return                             a DBN by value (presumably the trained network)
machine_learning::DBN train_DBN(
                                machine_learning::DBN        dbn,
                                // =====
                                const Matrix<double>        &X,
                                // =====
                                const int                    train_nepoch,
                                const int                    train_npts_per_batch,
                                // -----
                                const int                    train_K_start,
                                const int                    train_K_end,
                                const double                 train_K_rate,
                                // -----
                                const bool                   train_sample_v,
                                const bool                   train_sample_hdata,
                                // -----
                                const double                 train_lr,
                                const std::vector<double>    train_lr_dot,
                                const std::vector<double>    train_lr_adj,
                                const double                 train_lr_min,
                                const double                 train_lr_max,
                                // -----
                                const double                 train_p_start,
                                const double                 train_p_end,
                                const double                 train_p_rate,
                                const std::vector<double>    train_p_dot,
                                const std::vector<double>    train_p_adj,
                                // -----
                                const double                 train_w_penalty,
                                // -----
                                // NOTE: the free-energy parameters below are commented
                                // out and are NOT part of this function's signature.
                                //    int free_energy_npts;
                                //    int free_energy_nepoch;
                                //    int free_energy_window;
                                // -----
                                const int                    train_convg_nno_improvement,
                                const double                 train_convg_max_inc_max_w,
                                // =====
                                const std::string            dbn_file
                                );


/// @brief Evaluate a trained DBN on the data in X.
///
/// `dbn` is taken by const reference, so it is not modified through this
/// parameter. The function returns nothing. Any result is presumably written
/// to H_file (TODO confirm against the definition).
///
/// @param dbn     network to evaluate (read-only)
/// @param X       input data; presumably one sample per row (TODO confirm)
/// @param H_file  output path; presumably receives the hidden-layer
///                representation of X (TODO confirm contents and format)
void test_DBN(
              const machine_learning::DBN &dbn,
              // -----
              const Matrix<double>        &X,
              // -----
              const std::string            H_file
              );


#endif
